import os, sys
from load_data import *
import re
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset

def generate_idx_specific_input_output_list(lists, functions, indices):
    """Build parallel (input, output) lists for the requested sample indices.

    Args:
        lists: the input data, passed unchanged to every builder function.
        functions: a ``(prompt_function, input_function, output_function)``
            triple; each is called as ``fn(lists, i)``.
        indices: the sample indices to build examples for, visited once, in order.

    Returns:
        ``(input_list, output_list)`` where ``input_list[k]`` is
        ``prompt_function(lists, i) + input_function(lists, i)`` and
        ``output_list[k]`` is ``output_function(lists, i)`` for the k-th index.
    """
    build_prompt, build_input, build_output = functions

    # One pass over `indices`; the tuple is evaluated left to right, so each
    # index still calls prompt -> input -> output in that order.
    examples = [
        (build_prompt(lists, i) + build_input(lists, i), build_output(lists, i))
        for i in indices
    ]
    return [source for source, _ in examples], [target for _, target in examples]

def generate_rag_input_output_list(lists, functions, indices, context_tensor):
    """Build (input, output) lists with one example per (index, retrieved context).

    Args:
        lists: the input data, passed unchanged to every builder function.
        functions: a ``(prompt_function, input_function, output_function)``
            triple; the input builder is called as ``fn(lists, i, context)``,
            the other two as ``fn(lists, i)``.
        indices: sample indices; ``indices[k]`` pairs with ``context_tensor[k]``.
        context_tensor: per-index sequences of contexts (presumably the
            passage rows produced by ``string_tensor`` — confirm with caller).

    Returns:
        ``(input_list, output_list)``. Each index contributes one entry per
        context in its row, and the target output is repeated for each.

    Raises:
        AssertionError: if ``indices`` and ``context_tensor`` differ in length.
    """
    build_prompt, build_input, build_output = functions
    assert len(indices) == len(context_tensor)

    examples = [
        (build_prompt(lists, i) + build_input(lists, i, context), build_output(lists, i))
        for position, i in enumerate(indices)
        for context in context_tensor[position]
    ]
    return [source for source, _ in examples], [target for _, target in examples]

def matching_order(terms, sentence, rev=False):
    """Return the terms that occur in ``sentence``, ordered by position.

    Terms are matched literally (regex metacharacters such as ``(``, ``.``
    or ``+`` in a label are escaped). Terms that never occur are dropped.

    Args:
        terms: candidate strings to look for.
        sentence: the text to search.
        rev: if False, order by first occurrence, earliest first; if True,
            order by last occurrence, latest first.

    Returns:
        A list of the matching terms. Ties keep the order of ``terms``
        because ``sorted`` is stable.
    """
    term_occurrences = {}
    for term in terms:
        # Escape so labels are matched as plain text, not as regex patterns.
        matches = [m.start() for m in re.finditer(re.escape(term), sentence)]
        if matches:
            term_occurrences[term] = matches
    if rev:
        # Sort the terms based on their last occurrence in the sentence, the later the better
        sorted_terms = sorted(term_occurrences.items(), key=lambda x: x[1][-1], reverse=True)
    else:
        # Sort the terms based on their first occurrence in the sentence, the earlier the better
        sorted_terms = sorted(term_occurrences.items(), key=lambda x: x[1][0])
    return [term for term, _ in sorted_terms]  # only return the key

def accuracy(list1, list2):
    """Return the fraction of positions where ``list1`` and ``list2`` are equal.

    Raises:
        AssertionError: if the two sequences differ in length.
        ZeroDivisionError: if both sequences are empty.
    """
    assert len(list1) == len(list2)
    # Count by truthiness of `==`, one hit per matching position.
    matches = sum(1 for left, right in zip(list1, list2) if left == right)
    return matches / len(list1)

class Seq2SeqDataset(Dataset):
    """Map-style dataset that pairs each source sequence with its target."""

    def __init__(self, src_sequences, tgt_sequences):
        # The sequences are stored as given (no copy), then checked for equal length.
        self.src_sequences, self.tgt_sequences = src_sequences, tgt_sequences
        assert len(src_sequences) == len(tgt_sequences), "Source and target sequences must have the same length"

    def __len__(self):
        """Number of (source, target) pairs."""
        return len(self.src_sequences)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` tuple at position ``idx``."""
        return self.src_sequences[idx], self.tgt_sequences[idx]

def string_tensor(index_tensor, corpus, mask):
    """Replace retrieved corpus indices with their passage strings.

    Args:
        index_tensor: rows of retrieved indices; each element must support
            ``.item()`` (e.g. a torch tensor). Presumably shaped
            batch x (topk-1) — confirm with caller.
        corpus: an indexable collection of passages, keyed by ``idx.item()``.
        mask: same row/column layout as ``index_tensor``. Positions where the
            mask compares equal to False become the literal string "None".

    Returns:
        A list of lists of strings with the same row/column layout as
        ``index_tensor``.
    """
    # `== False` (not `is False`) so tensor/numpy mask elements compare by value.
    return [
        [
            "None" if mask[r][c] == False else corpus[entry.item()]  # noqa: E712
            for c, entry in enumerate(indices)
        ]
        for r, indices in enumerate(index_tensor)
    ]
    
"""
Early stop adapted from DGL implementation
"""
class EarlyStopping:
    """Track a validation score and signal when training should stop.

    Lower scores are better (e.g. a validation loss). When a score strictly
    improves on the best seen so far, the models are checkpointed via
    ``save_checkpoint`` and the patience counter resets. Otherwise the counter
    increases, and once it reaches ``patience`` the ``early_stop`` flag is set.
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0  # consecutive steps without improvement
        self.best_score = float('inf')  # the lower the better, in the case of loss
        self.best_epoch = None  # stays None until the first improvement
        self.early_stop = False

    def step(self, score, hf_models, epoch, output_dir):
        """Record ``score`` for ``epoch``, checkpointing on improvement.

        Args:
            score: the validation score for this step; lower is better.
            hf_models: ``(tokenizer, model, retriever)``, forwarded to
                ``save_checkpoint`` on improvement.
            epoch: stored as ``best_epoch`` on improvement and shown in the
                status string with ``:.3f`` formatting.
            output_dir: ``(llm_save_dir, retriever_save_dir)``, forwarded to
                ``save_checkpoint`` on improvement.

        Returns:
            ``(early_stop, status_string)``.
        """
        # `not score < best` instead of `score >= best`: a NaN score compares
        # False either way, so with `>=` it would count as an improvement and
        # overwrite the best checkpoint with a diverged model.
        if not score < self.best_score:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(hf_models, output_dir)
            self.counter = 0
        # best_epoch is None until the first improvement; applying ':.3f' to
        # None would raise TypeError.
        best_epoch = 'N/A' if self.best_epoch is None else f'{self.best_epoch:.3f}'
        es_str = f'{self.counter:02d}/Patience:{self.patience:02d} | BestVal={self.best_score:.3f}@Epoch{best_epoch}'
        return self.early_stop, es_str

    def save_checkpoint(self, hf_models, output_dirs):
        """Save the tokenizer, LLM and retriever when the validation score improves.

        Args:
            hf_models: ``(tokenizer, model, retriever)``. The tokenizer and
                model are saved with ``save_pretrained(dir)``; the retriever
                with ``retriever.model.save(dir)``.
            output_dirs: ``(llm_save_dir, retriever_save_dir)``. Each is
                created (with parents) if it does not exist.
        """
        llm_save_dir, retriever_save_dir = output_dirs

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(llm_save_dir, exist_ok=True)
        os.makedirs(retriever_save_dir, exist_ok=True)

        tokenizer, model, retriever = hf_models
        tokenizer.save_pretrained(llm_save_dir)
        model.save_pretrained(llm_save_dir)
        retriever.model.save(retriever_save_dir)

if __name__ == "__main__":
    terms = ["Probabilistic Methods", "Neural Networks"]
    sentence = "Among the candidates: (Neural Networks), (Probabilistic Methods), the correct label is: Probabilistic Methods"
    output = matching_order(terms, sentence, rev=True)
    print(output)